{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BojqmYOsPk0A"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "FtMmJ-pvPfNl"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IyATsAVlPz1W"
      },
      "source": [
        "# Inference with Gemma 2 using mistral.rs\n",
        "\n",
        "[Gemma](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open language models from Google. Built from the same research and technology used to create the Gemini models, Gemma models are text-to-text, decoder-only large language models (LLMs), available in English, with open weights for both pre-trained and instruction-tuned variants.\n",
        "Gemma models are well-suited for various text-generation tasks, including question-answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop, or your cloud infrastructure, democratizing access to state-of-the-art AI models and helping foster innovation for everyone.\n",
        "\n",
        "[mistral.rs](https://github.com/EricLBuehler/mistral.rs) is a versatile framework for Large Language Model (LLM) inference supporting text-to-text and multimodal LLMs. It offers features like grammar support for structured outputs, and inference using LoRA-fine-tuned models, making it a powerful tool for a wide range of AI applications.\n",
        "\n",
        "In this notebook, you will learn how to prompt the Gemma 2 model in various ways using the **mistral.rs** Python APIs in a Google Colab environment.\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Using_with_mistral_rs.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nFgaq_--Qg-O"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Select the Colab runtime\n",
        "This tutorial runs on the CPU runtime in Colab. Because CPU is the default for any Colab runtime, simply click the **Connect** button in the top-right corner of the Colab window.\n",
        "\n",
        "### Set up Hugging Face\n",
        "\n",
        "**Before you dive into the tutorial, let's get you set up with Hugging Face:**\n",
        "\n",
        "1. **Hugging Face Account:**  If you don't already have one, you can create a free Hugging Face account by clicking [here](https://huggingface.co/join).\n",
        "\n",
        "2. **Hugging Face Token:**  Generate a Hugging Face access token (preferably with `write` permission) by clicking [here](https://huggingface.co/settings/tokens). You'll need this token later in the tutorial.\n",
        "\n",
        "**Once you've completed these steps, you're ready to move on to the next section where you'll set up environment variables in your Colab environment.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bLEUJYZ8QmGz"
      },
      "source": [
        "### Configure your HF token\n",
        "Add your Hugging Face token to the Colab Secrets manager to securely store it.\n",
        "\n",
        "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "2. Create a new secret with the name `HF_TOKEN`.\n",
        "3. Copy/paste your HF token key into the Value input box of `HF_TOKEN`.\n",
        "4. Toggle the button on the left to allow notebook access to the secret."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hK-qUiQGQbe5"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5DrB-trDQsgE"
      },
      "source": [
        "### Install dependencies\n",
        "\n",
        "First, you must install the Python package for mistral.rs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LLjxxhk2Qrf_"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting mistralrs==0.3.2\n",
            "  Downloading mistralrs-0.3.2-cp310-cp310-manylinux_2_34_x86_64.whl.metadata (1.7 kB)\n",
            "Downloading mistralrs-0.3.2-cp310-cp310-manylinux_2_34_x86_64.whl (14.4 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.4/14.4 MB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: mistralrs\n",
            "Successfully installed mistralrs-0.3.2\n"
          ]
        }
      ],
      "source": [
        "!pip install mistralrs==0.3.2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_wsT7Eb_1esZ"
      },
      "source": [
        "## Overview\n",
        "\n",
        "In this notebook, you will explore how to prompt the Gemma 2 model using the Python APIs of mistral.rs. The notebook is divided into the following sections:\n",
        "\n",
        "1. Load the Gemma 2 model\n",
        "2. Response generation\n",
        "3. Non-streaming chat completion\n",
        "4. Streaming chat completion\n",
        "5. Inference with grammar\n",
        "\n",
        "Additionally, you can run inference on Gemma 2 using the Rust APIs and command-line interface of mistral.rs. Read more about them in the [mistral.rs documentation](https://github.com/EricLBuehler/mistral.rs?tab=readme-ov-file#get-started-fast-).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tMfQBPRYsknZ"
      },
      "source": [
        "## Load the Gemma 2 model\n",
        "\n",
        "In this section, you will learn how to load the Gemma 2 model from Hugging Face Hub using the mistral.rs Python APIs.\n",
        "\n",
        "First, create a `Which.Plain` instance that specifies the model to load. Set `model_id` to the Hugging Face Hub repository ID of the desired Gemma 2 model variant, and set `arch` to `Architecture.Gemma2`.\n",
        "\n",
        "Then create an instance of the `Runner` class, passing the `Which` instance as the `which` argument. Set `token_source` to `env:HF_TOKEN` so that mistral.rs reads your Hugging Face token from the `HF_TOKEN` environment variable when it downloads the model from Hugging Face Hub.\n",
        "\n",
        "Models can also be loaded from the local file system. Refer to the [mistral.rs get models guide](https://github.com/EricLBuehler/mistral.rs?tab=readme-ov-file#getting-models) for more details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VYFDjCBBtBhS"
      },
      "outputs": [],
      "source": [
        "from mistralrs import Runner, Which, Architecture\n",
        "\n",
        "runner = Runner(\n",
        "    which=Which.Plain(\n",
        "        model_id=\"google/gemma-2-2b-it\",\n",
        "        arch=Architecture.Gemma2,\n",
        "    ),\n",
        "    token_source=\"env:HF_TOKEN\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YoXapKI3vE4m"
      },
      "source": [
        "## Response generation\n",
        "\n",
        "Next, you'll send a prompt to Gemma 2 for inference using mistral.rs.\n",
        "\n",
        "To do this, create an instance of the `CompletionRequest` class, specifying your prompt in the `prompt` parameter. Set the `model` parameter to `\"gemma2\"`. mistral.rs lets you configure common LLM generation parameters such as `top_p`, `max_tokens`, and `temperature` within `CompletionRequest`. You can find the full list in the [class definition of `CompletionRequest`](https://github.com/EricLBuehler/mistral.rs/blob/458dc5f447161904ee9191fdec1dc5d4f039af54/mistralrs-pyo3/mistralrs.pyi#L44).\n",
        "\n",
        "Finally, call the `send_completion_request` method on the `runner` object, passing the newly created `CompletionRequest` instance as the `request` argument."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "niIzVlPbeSvO"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "A black hole is a region of spacetime where gravity is so strong that nothing, not even light, can escape. \n",
            "\n",
            "Here's a breakdown:\n",
            "\n",
            "**What causes a black hole?**\n",
            "\n",
            "* **Stellar Collapse:** When a massive star runs out of fuel, it collapses under its own gravity. This collapse can create a black hole if the star's mass is enough.\n",
            "* **Supermassive Black Holes:** These are found at the centers of most galaxies, including our own Milky Way. Their formation is still a mystery, but they likely formed from the merging of smaller black holes or from the collapse of massive gas clouds.\n",
            "\n",
            "**Key Features of a Black Hole:**\n",
            "\n",
            "* **Event Horizon:** This is the boundary around a black hole beyond which escape is impossible. Once something crosses the event horizon, it's trapped forever.\n",
            "* **Singularity:** This is the theoretical point at the center of a black hole where all the matter is compressed into an infinitely small point.\n",
            "* **Gravitational Pull:** Black holes have immense gravitational pull, which warps spacetime around them.\n",
            "\n",
            "**How do we know black holes exist?**\n",
            "\n",
            "* **Gravitational Effects:** We can detect black holes by observing their gravitational effects on nearby stars and gas.\n",
            "* **X-ray Emissions:** As matter falls into a black hole, it heats up and emits X-rays, which can be detected by telescopes.\n",
            "* **Gravitational Waves:** When black holes collide, they create ripples in spacetime called gravitational waves, which can be detected by specialized instruments.\n",
            "\n",
            "**Interesting Facts:**\n",
            "\n",
            "* Black holes are not actually \"holes\" in space, but rather regions of extreme density and gravity.\n",
            "* The size of a black hole is measured by its \"Schwarzschild radius,\" which is the distance from the center at which the escape velocity equals the speed of light.\n",
            "* Black holes are still a subject of active research, and scientists are constantly learning more about them.\n",
            "\n",
            "\n",
            "Let me know if you have any other questions! \n",
            "\n",
            "Usage {\n",
            "    completion_tokens: 415,\n",
            "    prompt_tokens: 7,\n",
            "    total_tokens: 422,\n",
            "    avg_tok_per_sec: 1.3271735,\n",
            "    avg_prompt_tok_per_sec: 1.4326648,\n",
            "    avg_compl_tok_per_sec: 1.3255271,\n",
            "    total_time_sec: 317.969,\n",
            "    total_prompt_time_sec: 4.886,\n",
            "    total_completion_time_sec: 313.083,\n",
            "}\n"
          ]
        }
      ],
      "source": [
        "from mistralrs import CompletionRequest\n",
        "\n",
        "request = CompletionRequest(\n",
        "    model=\"gemma2\",\n",
        "    prompt=\"What is a black hole?\",\n",
        "    max_tokens=512,\n",
        "    top_p=0.1,\n",
        "    temperature=0.1,\n",
        ")\n",
        "\n",
        "response = runner.send_completion_request(request)\n",
        "\n",
        "print(response.choices[0].text)\n",
        "print(response.usage)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vu1jh_nLMXyS"
      },
      "source": [
        "## Non-streaming chat completion\n",
        "\n",
        "In addition to basic prompting, mistral.rs also supports multi-turn conversations with LLMs.\n",
        "\n",
        "For multi-turn conversations, you maintain a list of dictionaries, each representing a single message in the conversation history with Gemma 2. Each dictionary should include a `role` key specifying whether the user or the assistant (model) generated the message, and a `content` key holding the message text.\n",
        "\n",
        "Create an instance of `ChatCompletionRequest` by passing the conversation history to the `messages` parameter. Then, invoke the `send_chat_completion_request` method of the `runner` with the newly created `ChatCompletionRequest` instance as the `request` argument.\n",
        "\n",
        "You can also configure any other model parameters you want to use. Refer to the [class definition of `ChatCompletionRequest`](https://github.com/EricLBuehler/mistral.rs/blob/458dc5f447161904ee9191fdec1dc5d4f039af54/mistralrs-pyo3/mistralrs.pyi#L11)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ffjhPyHpvEJM"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Neil Armstrong and Buzz Aldrin were both American astronauts. They represented the **United States** during the Apollo 11 mission. \n",
            "\n",
            "Usage {\n",
            "    completion_tokens: 29,\n",
            "    prompt_tokens: 74,\n",
            "    total_tokens: 103,\n",
            "    avg_tok_per_sec: 2.808376,\n",
            "    avg_prompt_tok_per_sec: 4.5204644,\n",
            "    avg_compl_tok_per_sec: 1.4281493,\n",
            "    total_time_sec: 36.676,\n",
            "    total_prompt_time_sec: 16.37,\n",
            "    total_completion_time_sec: 20.306,\n",
            "}\n"
          ]
        }
      ],
      "source": [
        "from mistralrs import ChatCompletionRequest\n",
        "\n",
        "messages = [\n",
        "    {\n",
        "        \"role\": \"user\",\n",
        "        \"content\": \"Who are the first humans to land on moon?\",\n",
        "    },\n",
        "    {\n",
        "        \"role\": \"assistant\",\n",
        "        \"content\": \"The first humans to land on the Moon were **Neil Armstrong and Buzz Aldrin** of the Apollo 11 mission on **July 20, 1969**.\",\n",
        "    },\n",
        "    {\n",
        "        \"role\": \"user\",\n",
        "        \"content\": \"Which country did they belong to?\",\n",
        "    },\n",
        "]\n",
        "\n",
        "request = ChatCompletionRequest(\n",
        "    model=\"gemma2\",\n",
        "    messages=messages,\n",
        "    max_tokens=256,\n",
        "    presence_penalty=1.0,\n",
        "    top_p=0.1,\n",
        "    temperature=0.1,\n",
        ")\n",
        "\n",
        "response = runner.send_chat_completion_request(request)\n",
        "\n",
        "print(response.choices[0].message.content)\n",
        "print(response.usage)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7W-F2n_xhWRl"
      },
      "source": [
        "Notice that the model resolved \"they\" in the last user query from the earlier conversation history: it answered about Neil Armstrong and Buzz Aldrin without being told again who they are."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qZehF79UOVWg"
      },
      "source": [
        "## Streaming chat completion\n",
        "\n",
        "mistral.rs also supports streaming responses from the model during multi-turn conversations. To enable streaming, set the `stream` parameter of `ChatCompletionRequest` to `True`.\n",
        "\n",
        "With streaming enabled, the `send_chat_completion_request` method returns an iterable object instead of a single response. You can access each chunk of the streamed response by iterating over this object.\n",
        "\n",
        "Pass the conversation history (`messages`) you created earlier to the `ChatCompletionRequest` instance with `stream=True`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nxNg6cXFog29"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Neil Armstrong and Buzz Aldrin were both American astronauts. They represented the **United States** during the Apollo 11 mission. \n"
          ]
        }
      ],
      "source": [
        "request = ChatCompletionRequest(\n",
        "    model=\"gemma2\",\n",
        "    messages=messages,\n",
        "    max_tokens=256,\n",
        "    presence_penalty=1.0,\n",
        "    top_p=0.1,\n",
        "    temperature=0.1,\n",
        "    stream=True,\n",
        ")\n",
        "\n",
        "response = runner.send_chat_completion_request(request)\n",
        "\n",
        "for chunk in response:\n",
        "    print(chunk.choices[0].delta.content, end=\"\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3RALnmwQP6WQ"
      },
      "source": [
        "## Inference with grammar\n",
        "\n",
        "Specifying a grammar when creating a `ChatCompletionRequest` or `CompletionRequest` allows you to constrain the model's output to a specific format.\n",
        "\n",
        "To ensure the model generates responses that match a regular expression, set the `grammar_type` parameter to `\"regex\"` and `grammar` to the desired regular expression string.\n",
        "\n",
        "The following code snippet illustrates how to restrict the model's response to a two-digit number using a regular expression grammar:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gi2FGoXJv-K9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "13\n",
            "Usage {\n",
            "    completion_tokens: 5,\n",
            "    prompt_tokens: 31,\n",
            "    total_tokens: 36,\n",
            "    avg_tok_per_sec: 3.115804,\n",
            "    avg_prompt_tok_per_sec: 3.5123498,\n",
            "    avg_compl_tok_per_sec: 1.8328446,\n",
            "    total_time_sec: 11.554,\n",
            "    total_prompt_time_sec: 8.826,\n",
            "    total_completion_time_sec: 2.728,\n",
            "}\n"
          ]
        }
      ],
      "source": [
        "request = CompletionRequest(\n",
        "    model=\"gemma2\",\n",
        "    prompt=\"What is the next number in the fibonnaci series: 1, 1, 2, 3, 5, 8 ?\",\n",
        "    grammar_type=\"regex\",\n",
        "    grammar=r\"\\d{2}\",\n",
        "    top_p=0.1,\n",
        "    temperature=0.1,\n",
        ")\n",
        "\n",
        "response = runner.send_completion_request(request)\n",
        "\n",
        "print(response.choices[0].text)\n",
        "print(response.usage)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X72NYX92nv40"
      },
      "source": [
        "mistral.rs offers support for various grammar types beyond regular expressions. Refer to the [mistral.rs documentation](https://github.com/EricLBuehler/mistral.rs/tree/master?tab=readme-ov-file#description) for more details."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vlRAC8OFa1pQ"
      },
      "source": [
        "These are just a few examples of how you can perform inference with Gemma 2 using mistral.rs. To explore its capabilities further, you can refer to the [mistral.rs documentation](https://github.com/EricLBuehler/mistral.rs/tree/master?tab=readme-ov-file#--mistralrs).\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "[Gemma_2]Using_with_mistral_rs.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
